Skip to content

Conversation

@codeflash-ai
Copy link
Contributor

@codeflash-ai codeflash-ai bot commented Sep 27, 2025

⚡️ This pull request contains optimizations for PR #769

If you approve this dependent PR, these changes will be merged into the original PR branch clean-async-branch.

This PR will be automatically closed if the original PR is merged.


📄 39% (0.39x) speedup for get_first_top_level_function_or_method_ast in codeflash/code_utils/static_analysis.py

⏱️ Runtime : 1.43 milliseconds 1.03 milliseconds (best of 61 runs)

📝 Explanation and details

The optimized code achieves a 38% speedup through several key micro-optimizations in AST traversal:

Primary optimizations:

  1. Reduced tuple allocation overhead: Moving skip_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef) to a local variable eliminates repeated tuple construction on each function call (128 calls show 0.5% overhead vs previous inline tuple creation).

  2. Improved iterator efficiency: Converting ast.iter_child_nodes(node) to list(ast.iter_child_nodes(node)) upfront provides better cache locality and eliminates generator overhead during iteration, though this comes with a memory trade-off.

  3. Optimized control flow: Restructuring the isinstance checks to handle the common case (finding matching object_type) first, then using early continue statements to skip unnecessary processing, reduces the total number of isinstance calls from ~14,000 to ~11,000.

  4. Eliminated walrus operator complexity: Simplifying the class_node assignment in get_first_top_level_function_or_method_ast removes the complex conditional expression, making the code path more predictable.

Performance characteristics:

  • The optimizations are most effective for large-scale test cases with many classes/functions (500+ nodes), where the reduced overhead per iteration compounds significantly
  • Basic test cases see modest improvements since the overhead reduction is less impactful on smaller AST trees
  • The memory trade-off of list conversion is worthwhile because AST child node lists are typically small and the improved iteration speed outweighs the memory cost

The line profiler shows the optimized version spends more time in the initial list conversion (49.9% vs 46% in the original iterator), but this is offset by faster subsequent processing of the child nodes.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 68 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
from __future__ import annotations

import ast
from typing import TypeVar

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.static_analysis import \
    get_first_top_level_function_or_method_ast
from codeflash.models.models import FunctionParent

# unit tests

# Helper to parse code and call the function
def get_func_ast(code: str, func_name: str, parents=None):
    node = ast.parse(code)
    return get_first_top_level_function_or_method_ast(func_name, parents or [], node)

# Basic Test Cases

def test_basic_top_level_function():
    """Test finding a top-level function."""
    code = "def foo(): pass"
    func_ast = get_func_ast(code, "foo")

def test_basic_top_level_async_function():
    """Test finding a top-level async function."""
    code = "async def foo(): pass"
    func_ast = get_func_ast(code, "foo")

def test_basic_method_in_class():
    """Test finding a method inside a class."""
    code = "class Bar:\n    def foo(self): pass"
    parents = [FunctionParent(name="Bar", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

def test_basic_async_method_in_class():
    """Test finding an async method inside a class."""
    code = "class Bar:\n    async def foo(self): pass"
    parents = [FunctionParent(name="Bar", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

def test_basic_multiple_functions():
    """Test multiple top-level functions; should find the first one."""
    code = "def foo(): pass\ndef foo(): pass"
    func_ast = get_func_ast(code, "foo")

def test_basic_multiple_methods_in_class():
    """Test multiple methods in a class; should find the first one."""
    code = "class Bar:\n    def foo(self): pass\n    def foo(self): pass"
    parents = [FunctionParent(name="Bar", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

# Edge Test Cases

def test_edge_no_function_found():
    """Test when function is not present."""
    code = "def bar(): pass"
    func_ast = get_func_ast(code, "foo")

def test_edge_no_method_found_in_class():
    """Test when method is not present in class."""
    code = "class Bar:\n    def bar(self): pass"
    parents = [FunctionParent(name="Bar", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

def test_edge_nested_function_not_top_level():
    """Test that nested functions are not returned."""
    code = "def foo():\n    def bar(): pass"
    func_ast = get_func_ast(code, "bar")

def test_edge_nested_method_not_top_level():
    """Test that nested methods in class are not returned."""
    code = "class Bar:\n    def foo(self):\n        def bar(): pass"
    parents = [FunctionParent(name="Bar", type="ClassDef")]
    func_ast = get_func_ast(code, "bar", parents)

def test_edge_function_in_inner_class_not_found():
    """Test that method in inner class is not found if parent is outer class."""
    code = "class Outer:\n    class Inner:\n        def foo(self): pass"
    parents = [FunctionParent(name="Outer", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

def test_edge_method_with_wrong_parent_type():
    """Test that method is not found if parent type is not ClassDef."""
    code = "class Bar:\n    def foo(self): pass"
    parents = [FunctionParent(name="Bar", type="FunctionDef")]
    func_ast = get_func_ast(code, "foo", parents)

def test_edge_function_with_same_name_in_different_classes():
    """Test that correct parent is used."""
    code = (
        "class Bar:\n    def foo(self): pass\n"
        "class Baz:\n    def foo(self): pass"
    )
    parents = [FunctionParent(name="Baz", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

def test_edge_function_with_same_name_top_level_and_in_class():
    """Test that top-level function is found if no parent, otherwise method in class."""
    code = (
        "def foo(): pass\n"
        "class Bar:\n    def foo(self): pass"
    )
    # Top-level
    func_ast = get_func_ast(code, "foo")
    # In class
    parents = [FunctionParent(name="Bar", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

def test_edge_async_and_sync_function_same_name():
    """Test that sync function is preferred over async if both exist."""
    code = "def foo(): pass\nasync def foo(): pass"
    func_ast = get_func_ast(code, "foo")

def test_edge_async_and_sync_method_same_name_in_class():
    """Test that sync method is preferred over async if both exist in class."""
    code = "class Bar:\n    def foo(self): pass\n    async def foo(self): pass"
    parents = [FunctionParent(name="Bar", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

def test_edge_empty_code():
    """Test with empty code."""
    code = ""
    func_ast = get_func_ast(code, "foo")

def test_edge_function_with_decorator():
    """Test function with a decorator."""
    code = "@staticmethod\ndef foo(): pass"
    func_ast = get_func_ast(code, "foo")

def test_edge_method_with_decorator_in_class():
    """Test method with a decorator inside class."""
    code = "class Bar:\n    @staticmethod\n    def foo(): pass"
    parents = [FunctionParent(name="Bar", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

def test_edge_class_with_same_name_as_function():
    """Test that class with same name as function does not interfere."""
    code = "def foo(): pass\nclass foo: pass"
    func_ast = get_func_ast(code, "foo")

def test_edge_function_with_args():
    """Test function with arguments."""
    code = "def foo(a, b=1): pass"
    func_ast = get_func_ast(code, "foo")

def test_edge_method_with_args_in_class():
    """Test method with arguments inside class."""
    code = "class Bar:\n    def foo(self, a, b=1): pass"
    parents = [FunctionParent(name="Bar", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

def test_edge_function_with_docstring():
    """Test function with docstring."""
    code = 'def foo():\n    """Docstring"""\n    pass'
    func_ast = get_func_ast(code, "foo")

def test_edge_method_with_docstring_in_class():
    """Test method with docstring inside class."""
    code = 'class Bar:\n    def foo(self):\n        """Docstring"""\n        pass'
    parents = [FunctionParent(name="Bar", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

def test_edge_function_with_inner_class():
    """Test that function inside inner class is not found as top-level."""
    code = "class Bar:\n    class Inner:\n        def foo(self): pass"
    parents = [FunctionParent(name="Bar", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

def test_edge_function_with_inner_function():
    """Test that inner function is not found as top-level."""
    code = "def outer():\n    def foo(): pass"
    func_ast = get_func_ast(code, "foo")

def test_edge_class_with_method_and_inner_function():
    """Test that only top-level method in class is found, not inner function."""
    code = "class Bar:\n    def foo(self):\n        def inner(): pass"
    parents = [FunctionParent(name="Bar", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

# Large Scale Test Cases

def test_large_many_top_level_functions():
    """Test with many top-level functions, only first should be found."""
    code = "\n".join([f"def foo(): pass" if i == 0 else f"def bar{i}(): pass" for i in range(500)])
    func_ast = get_func_ast(code, "foo")

def test_large_many_classes_with_same_method_name():
    """Test with many classes having same method name, only correct parent should be found."""
    code = "\n".join([
        f"class Bar{i}:\n    def foo(self): pass"
        for i in range(500)
    ])
    parents = [FunctionParent(name="Bar123", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

def test_large_many_methods_in_one_class():
    """Test with one class containing many methods, only first should be found."""
    code = "class Bar:\n" + "\n".join([f"    def foo(self): pass" if i == 0 else f"    def bar{i}(self): pass" for i in range(500)])
    parents = [FunctionParent(name="Bar", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

def test_large_many_top_level_functions_same_name():
    """Test with many top-level functions of the same name, should find the first one."""
    code = "\n".join([f"def foo(): pass" for _ in range(500)])
    func_ast = get_func_ast(code, "foo")

def test_large_many_classes_with_methods_same_name():
    """Test with many classes, each having many methods of the same name, should find the first in correct class."""
    code = "\n".join([
        f"class Bar{i}:\n" + "\n".join([f"    def foo(self): pass" for _ in range(5)])
        for i in range(200)
    ])
    parents = [FunctionParent(name="Bar123", type="ClassDef")]
    func_ast = get_func_ast(code, "foo", parents)

def test_large_deeply_nested_classes_and_functions():
    """Test that deeply nested functions are not found as top-level."""
    code = "def foo():\n" + "\n".join(["    def bar{}(): pass".format(i) for i in range(500)])
    func_ast = get_func_ast(code, "bar499")

def test_large_deeply_nested_methods_in_class():
    """Test that deeply nested methods in class are not found as top-level."""
    code = "class Bar:\n    def foo(self):\n" + "\n".join(["        def bar{}(): pass".format(i) for i in range(500)])
    parents = [FunctionParent(name="Bar", type="ClassDef")]
    func_ast = get_func_ast(code, "bar499", parents)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from __future__ import annotations

import ast
# Simulate codeflash.models.models.FunctionParent for testing
from dataclasses import dataclass
from typing import TypeVar

# imports
import pytest  # used for our unit tests
from codeflash.code_utils.static_analysis import \
    get_first_top_level_function_or_method_ast


@dataclass
class FunctionParent:
    name: str
    type: str  # e.g., "ClassDef"
from codeflash.code_utils.static_analysis import \
    get_first_top_level_function_or_method_ast

# ----------- UNIT TESTS ---------------

# Helper to parse code and return AST
def parse_code(code: str) -> ast.Module:
    return ast.parse(code)

# BASIC TESTS

def test_basic_top_level_function_found():
    # Test finding a basic top-level function
    code = """
def foo():
    pass
"""
    tree = parse_code(code)
    codeflash_output = get_first_top_level_function_or_method_ast("foo", [], tree); node = codeflash_output

def test_basic_top_level_async_function_found():
    # Test finding a top-level async function
    code = """
async def bar():
    pass
"""
    tree = parse_code(code)
    codeflash_output = get_first_top_level_function_or_method_ast("bar", [], tree); node = codeflash_output

def test_basic_class_method_found():
    # Test finding a method inside a class
    code = """
class MyClass:
    def method(self):
        pass
"""
    tree = parse_code(code)
    parent = [FunctionParent(name="MyClass", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("method", parent, tree); node = codeflash_output

def test_basic_class_async_method_found():
    # Test finding an async method inside a class
    code = """
class MyClass:
    async def amethod(self):
        pass
"""
    tree = parse_code(code)
    parent = [FunctionParent(name="MyClass", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("amethod", parent, tree); node = codeflash_output

def test_basic_function_not_found():
    # Test that None is returned if function is not present
    code = """
def foo():
    pass
"""
    tree = parse_code(code)
    codeflash_output = get_first_top_level_function_or_method_ast("bar", [], tree); node = codeflash_output

def test_basic_method_not_found_in_class():
    # Test that None is returned if method is not present in class
    code = """
class MyClass:
    def foo(self):
        pass
"""
    tree = parse_code(code)
    parent = [FunctionParent(name="MyClass", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("bar", parent, tree); node = codeflash_output

# EDGE CASES

def test_function_with_same_name_in_class_and_top_level():
    # Should find top-level function if parents is empty
    code = """
def foo():
    pass

class MyClass:
    def foo(self):
        pass
"""
    tree = parse_code(code)
    codeflash_output = get_first_top_level_function_or_method_ast("foo", [], tree); node = codeflash_output
    # Should find class method if parent specified
    parent = [FunctionParent(name="MyClass", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("foo", parent, tree); node2 = codeflash_output

def test_nested_class_method_not_found():
    # Should not find method in nested class if parent is top-level class
    code = """
class Outer:
    class Inner:
        def foo(self):
            pass
"""
    tree = parse_code(code)
    parent = [FunctionParent(name="Outer", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("foo", parent, tree); node = codeflash_output

def test_function_with_decorator():
    # Should find function even if it has decorators
    code = """
@staticmethod
def foo():
    pass
"""
    tree = parse_code(code)
    codeflash_output = get_first_top_level_function_or_method_ast("foo", [], tree); node = codeflash_output

def test_class_method_with_decorator():
    # Should find decorated method in class
    code = """
class MyClass:
    @classmethod
    def foo(cls):
        pass
"""
    tree = parse_code(code)
    parent = [FunctionParent(name="MyClass", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("foo", parent, tree); node = codeflash_output

def test_function_in_if_block_not_found():
    # Should not find function defined inside an if block at top level
    code = """
if True:
    def foo():
        pass
"""
    tree = parse_code(code)
    codeflash_output = get_first_top_level_function_or_method_ast("foo", [], tree); node = codeflash_output

def test_function_in_nested_function_not_found():
    # Should not find function nested inside another function
    code = """
def outer():
    def foo():
        pass
"""
    tree = parse_code(code)
    codeflash_output = get_first_top_level_function_or_method_ast("foo", [], tree); node = codeflash_output

def test_class_with_no_methods():
    # Should return None for method in class with no methods
    code = """
class MyClass:
    pass
"""
    tree = parse_code(code)
    parent = [FunctionParent(name="MyClass", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("foo", parent, tree); node = codeflash_output

def test_class_not_found():
    # Should return None if class is not present
    code = """
def foo():
    pass
"""
    tree = parse_code(code)
    parent = [FunctionParent(name="MissingClass", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("foo", parent, tree); node = codeflash_output

def test_multiple_functions_same_name_top_level():
    # Should find the first top-level function with the given name
    code = """
def foo():
    pass

def foo():
    pass
"""
    tree = parse_code(code)
    codeflash_output = get_first_top_level_function_or_method_ast("foo", [], tree); node = codeflash_output

def test_multiple_methods_same_name_in_class():
    # Should find the first method with given name in class
    code = """
class MyClass:
    def foo(self):
        pass
    def foo(self):
        pass
"""
    tree = parse_code(code)
    parent = [FunctionParent(name="MyClass", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("foo", parent, tree); node = codeflash_output
    # Should be the first method in the class body
    class_node = tree.body[0]


def lambda():
    pass
"""
    tree = parse_code(code)
    codeflash_output = get_first_top_level_function_or_method_ast("lambda", [], tree); node = codeflash_output


def test_function_with_unicode_name():
    # Should handle function with unicode name
    code = """
def café():
    pass
"""
    tree = parse_code(code)
    codeflash_output = get_first_top_level_function_or_method_ast("café", [], tree); node = codeflash_output

def test_class_with_unicode_name():
    # Should handle class with unicode name
    code = """
class Café:
    def foo(self):
        pass
"""
    tree = parse_code(code)
    parent = [FunctionParent(name="Café", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("foo", parent, tree); node = codeflash_output

def test_function_with_non_ascii_name():
    # Should handle function with non-ascii name
    code = """
def привет():
    pass
"""
    tree = parse_code(code)
    codeflash_output = get_first_top_level_function_or_method_ast("привет", [], tree); node = codeflash_output

# LARGE SCALE TESTS

def test_large_number_of_functions():
    # Test performance with many top-level functions
    code = "\n".join([f"def func{i}(): pass" for i in range(500)])
    tree = parse_code(code)
    # Should find func499
    codeflash_output = get_first_top_level_function_or_method_ast("func499", [], tree); node = codeflash_output
    # Should find func0
    codeflash_output = get_first_top_level_function_or_method_ast("func0", [], tree); node = codeflash_output
    # Should return None for non-existent function
    codeflash_output = get_first_top_level_function_or_method_ast("func500", [], tree); node = codeflash_output

def test_large_number_of_classes_and_methods():
    # Test performance with many classes and methods
    code = "\n".join([
        f"class Class{i}:\n    def method{i}(self): pass"
        for i in range(500)
    ])
    tree = parse_code(code)
    # Should find method499 in Class499
    parent = [FunctionParent(name="Class499", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("method499", parent, tree); node = codeflash_output
    # Should return None for non-existent method
    parent = [FunctionParent(name="Class500", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("method500", parent, tree); node = codeflash_output

def test_large_number_of_methods_in_one_class():
    # Test performance with many methods in a single class
    methods = "\n".join([f"    def method{i}(self): pass" for i in range(500)])
    code = f"""
class BigClass:
{methods}
"""
    tree = parse_code(code)
    parent = [FunctionParent(name="BigClass", type="ClassDef")]
    # Should find method0
    codeflash_output = get_first_top_level_function_or_method_ast("method0", parent, tree); node = codeflash_output
    # Should find method499
    codeflash_output = get_first_top_level_function_or_method_ast("method499", parent, tree); node = codeflash_output
    # Should return None for non-existent method
    codeflash_output = get_first_top_level_function_or_method_ast("method500", parent, tree); node = codeflash_output

def test_large_number_of_classes_with_same_method_name():
    # Test with many classes all having the same method name
    code = "\n".join([
        f"class Class{i}:\n    def foo(self): pass"
        for i in range(500)
    ])
    tree = parse_code(code)
    # Should find foo in Class0
    parent = [FunctionParent(name="Class0", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("foo", parent, tree); node = codeflash_output
    # Should find foo in Class499
    parent = [FunctionParent(name="Class499", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("foo", parent, tree); node = codeflash_output

def test_large_scale_function_and_method_mix():
    # Large module with mixed functions and classes
    code = "\n".join(
        [f"def func{i}(): pass" for i in range(250)] +
        [f"class Class{i}:\n    def method{i}(self): pass" for i in range(250)]
    )
    tree = parse_code(code)
    # Find last function
    codeflash_output = get_first_top_level_function_or_method_ast("func249", [], tree); node = codeflash_output
    # Find last class method
    parent = [FunctionParent(name="Class249", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("method249", parent, tree); node = codeflash_output
    # Non-existent function/method
    codeflash_output = get_first_top_level_function_or_method_ast("func250", [], tree); node = codeflash_output
    parent = [FunctionParent(name="Class250", type="ClassDef")]
    codeflash_output = get_first_top_level_function_or_method_ast("method250", parent, tree); node = codeflash_output
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-pr769-2025-09-27T03.15.10 and push.

Codeflash

The optimized code achieves a 38% speedup through several key micro-optimizations in AST traversal:

**Primary optimizations:**
1. **Reduced tuple allocation overhead**: Moving `skip_types = (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)` to a local variable eliminates repeated tuple construction on each function call (128 calls show 0.5% overhead vs previous inline tuple creation).

2. **Improved iterator efficiency**: Converting `ast.iter_child_nodes(node)` to `list(ast.iter_child_nodes(node))` upfront provides better cache locality and eliminates generator overhead during iteration, though this comes with a memory trade-off.

3. **Optimized control flow**: Restructuring the isinstance checks to handle the common case (finding matching object_type) first, then using early `continue` statements to skip unnecessary processing, reduces the total number of isinstance calls from ~14,000 to ~11,000.

4. **Eliminated walrus operator complexity**: Simplifying the class_node assignment in `get_first_top_level_function_or_method_ast` removes the complex conditional expression, making the code path more predictable.

**Performance characteristics:**
- The optimizations are most effective for **large-scale test cases** with many classes/functions (500+ nodes), where the reduced overhead per iteration compounds significantly
- **Basic test cases** see modest improvements since the overhead reduction is less impactful on smaller AST trees
- The memory trade-off of list conversion is worthwhile because AST child node lists are typically small and the improved iteration speed outweighs the memory cost

The line profiler shows the optimized version spends more time in the initial list conversion (49.9% vs 46% in the original iterator), but this is offset by faster subsequent processing of the child nodes.
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash Optimization PR opened by Codeflash AI label Sep 27, 2025
@KRRT7 KRRT7 closed this Sep 27, 2025
@codeflash-ai codeflash-ai bot deleted the codeflash/optimize-pr769-2025-09-27T03.15.10 branch September 27, 2025 03:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant